{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vector Similarity Search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates usage of the vector similarity search, to help the agent select an exact query value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "23/10/17 13:50:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "from pyspark_ai import SparkAI\n",
    "\n",
    "# enable vector similarity search by specifying a vector_store_dir location\n",
    "spark_ai = SparkAI(enable_cache=False, vector_store_dir=\"temp/\", verbose=True)\n",
    "spark_ai.activate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# create a dataframe with inconsistent date formats\n",
    "scores_df = spark_ai._spark.createDataFrame(\n",
    "    [\n",
    "        (\"09-10-2017\", \"24-20\", \"New England Patriots\"),\n",
    "        (\"2017-17-09\", \"27-23\", \"Green Bay Packers\"),\n",
    "        (\"09/24/2017\", \"30-26\", \"Kansas City Chiefs\"),\n",
    "        (\"10-01-2017\", \"26-24\", \"Los Angeles Rams\"),\n",
    "        (\"10/08/2017\", \"23-10\", \"Baltimore Ravens\"),\n",
    "        (\"10/15/17\", \"20-17\", \"Miami Dolphins\"),\n",
    "        (\"10-22-2017\", \"33-31\", \"Chicago Bears\"),\n",
    "        (\"2017-10-29\", \"27-24\", \"Philadelphia Eagles\"),\n",
    "        (\"11-05-2017\", \"28-23\", \"Carolina Panthers\"),\n",
    "        (\"11-12-2017\", \"27-24\", \"New Orleans Saints\")\n",
    "    ],\n",
    "    [\"date\", \"score\", \"winner\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92mINFO: \u001b[0mCreating temp view for the transform:\n",
      "df.createOrReplaceTempView(\u001b[33m\"\u001b[39;49;00m\u001b[33mspark_ai_temp_view_1345058194\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\u001b[37m\u001b[39;49;00m\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: The keyword 'October 29th' is most similar to the sample values in the `date` column.\n",
      "I need to filter on an exact value from the `date` column, so I will use the tool similar_value to help me choose my filter value.\n",
      "Action: similar_value\n",
      "Action Input: October 29th|date|spark_ai_temp_view_1345058194\u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3m2017-10-29\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe correct `date` filter should be '2017-10-29' because it is semantically closest to the keyword.\n",
      "I will use the column `date` to filter the rows where its value is '2017-10-29' and then select the `winner`.\n",
      "Action: query_validation\n",
      "Action Input: SELECT `winner` FROM `spark_ai_temp_view_1345058194` WHERE `date` = '2017-10-29'\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mOK\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
      "Final Answer: SELECT `winner` FROM `spark_ai_temp_view_1345058194` WHERE `date` = '2017-10-29'\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "+-------------------+\n",
      "|             winner|\n",
      "+-------------------+\n",
      "|Philadelphia Eagles|\n",
      "+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "scores_df.ai.transform(desc=\"Which team won on October 29th?\", cache=False).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92mINFO: \u001b[0mCreating temp view for the transform:\n",
      "df.createOrReplaceTempView(\u001b[33m\"\u001b[39;49;00m\u001b[33mspark_ai_temp_view_1345058194\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\u001b[37m\u001b[39;49;00m\n",
      "\n",
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: The keyword '23 v 28' is most similar to the sample values in the `score` column.\n",
      "I need to filter on an exact value from the `score` column, so I will use the tool similar_value to help me choose my filter value.\n",
      "Action: similar_value\n",
      "Action Input: 23 v 28|score|spark_ai_temp_view_1345058194\u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3m28-23\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe correct `score` filter should be '28-23' because it is semantically closest to the keyword.\n",
      "I will use the column `score` to filter the rows where its value is '28-23' and then select the `date` \n",
      "because the question asks for the date when the score was '28-23'.\n",
      "Action: query_validation\n",
      "Action Input: SELECT `date` FROM `spark_ai_temp_view_1345058194` WHERE `score` = '28-23'\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mOK\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
      "Final Answer: SELECT `date` FROM `spark_ai_temp_view_1345058194` WHERE `score` = '28-23'\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "+----------+\n",
      "|      date|\n",
      "+----------+\n",
      "|11-05-2017|\n",
      "+----------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "scores_df.ai.transform(desc=\"When was the score 23 v 28?\", cache=False).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92mINFO: \u001b[0mCreating temp view for the transform:\n",
      "df.createOrReplaceTempView(\u001b[33m\"\u001b[39;49;00m\u001b[33mspark_ai_temp_view_1345058194\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\u001b[37m\u001b[39;49;00m\n",
      "\n",
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: The column names are non-descriptive, but from the sample values I see that column `winner` contains team names\n",
      "and column `date` contains dates. So, I will filter on column `winner` for 'Chiefs' and select column `date` for the date.\n",
      "I will use = rather than \"like\" in my SQL query because I need an exact match. I'll use the similar_value tool to help\n",
      "me choose an exact filter value for `winner`.\n",
      "Action: similar_value\n",
      "Action Input: Chiefs|winner|spark_ai_temp_view_1345058194\u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3mKansas City Chiefs\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe correct `winner` filter should be 'Kansas City Chiefs' because it is semantically closest to the keyword.\n",
      "Action: query_validation\n",
      "Action Input: SELECT `date` FROM `spark_ai_temp_view_1345058194` WHERE `winner` = 'Kansas City Chiefs'\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mOK\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
      "Final Answer: SELECT `date` FROM `spark_ai_temp_view_1345058194` WHERE `winner` = 'Kansas City Chiefs'\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "+----------+\n",
      "|      date|\n",
      "+----------+\n",
      "|09/24/2017|\n",
      "+----------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "scores_df.ai.transform(desc=\"When did the Chiefs win?\", cache=False).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a dataframe\n",
    "countries_df = spark_ai._spark.createDataFrame(\n",
    "    [\n",
    "        (\"USA\", \"21.43 trillion USD\", 2.7, \"331 million\", 0.7),\n",
    "        (\"CHN\", \"14.34 trillion USD\", 7.0, \"1.4 billion\", 0.4),\n",
    "        (\"IND\", \"2.87 trillion USD\", 8.0, \"1.3 billion\", 1.2),\n",
    "        (\"BRA\", \"1.87 trillion USD\", -3.5, \"211 million\", 0.7),\n",
    "        (\"CAN\", \"1.64 trillion USD\", 0.7, \"38 million\", 0.8),\n",
    "        (\"AUS\", \"1.37 trillion USD\", 1.0, \"25 million\", 1.1),\n",
    "        (\"RUS\", \"1.47 trillion USD\", -2.0, \"144 million\", 0.04),\n",
    "        (\"FRA\", \"2.71 trillion USD\", 1.3, \"67 million\", 0.3),\n",
    "        (\"GER\", \"3.68 trillion USD\", 0.8, \"83 million\", 0.2),\n",
    "        (\"JPN\", \"5.08 trillion USD\", 1.6, \"126 million\", 0.2),\n",
    "        (\"GBR\", \"2.83 trillion USD\", 0.3, \"67 million\", 0.6),\n",
    "        (\"ZAF\", \"352 billion USD\", 1.2, \"60 million\", 1.2)\n",
    "    ],\n",
    "    [\"country\", \"gdp\", \"gdp growth (%)\", \"population\", \"population growth (%)\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92mINFO: \u001b[0mCreating temp view for the transform:\n",
      "df.createOrReplaceTempView(\u001b[33m\"\u001b[39;49;00m\u001b[33mspark_ai_temp_view__1575533453\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\u001b[37m\u001b[39;49;00m\n",
      "\n",
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: The keyword '38,000,000' is most similar to the sample values in the `population` column.\n",
      "I need to filter on an exact value from the `population` column, so I will use the tool similar_value to help me choose my filter value.\n",
      "Action: similar_value\n",
      "Action Input: 38,000,000|population|spark_ai_temp_view__1575533453\u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3m38 million\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe correct `population` filter should be '38 million' because it is semantically closest to the keyword.\n",
      "I will use the column `population` to filter the rows where its value is '38 million' and then select the `gdp` \n",
      "because the question asks for the GDP of the country.\n",
      "Action: query_validation\n",
      "Action Input: SELECT `gdp` FROM `spark_ai_temp_view__1575533453` WHERE `population` = '38 million'\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mOK\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
      "Final Answer: SELECT `gdp` FROM `spark_ai_temp_view__1575533453` WHERE `population` = '38 million'\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "+-----------------+\n",
      "|              gdp|\n",
      "+-----------------+\n",
      "|1.64 trillion USD|\n",
      "+-----------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "countries_df.ai.transform(desc=\"What is the GDP of the country that has 38,000,000 people?\", cache=False).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92mINFO: \u001b[0mCreating temp view for the transform:\n",
      "df.createOrReplaceTempView(\u001b[33m\"\u001b[39;49;00m\u001b[33mspark_ai_temp_view__1575533453\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\u001b[37m\u001b[39;49;00m\n",
      "\n",
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: The keyword '.7 percent population growth' is most similar to the sample values in the `population growth (%)` column.\n",
      "I need to filter on an exact value from the `population growth (%)` column, so I will use the tool similar_value to help me choose my filter value.\n",
      "Action: similar_value\n",
      "Action Input: .7|population growth (%)|spark_ai_temp_view__1575533453\u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3m0.7\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe correct `population growth (%)` filter should be '0.7' because it is semantically closest to the keyword.\n",
      "I will use the column `population growth (%)` to filter the rows where its value is '0.7' and then select the AVG(`gdp growth (%)`) \n",
      "because the question asks for the average GDP growth.\n",
      "Action: query_validation\n",
      "Action Input: SELECT AVG(`gdp growth (%)`) FROM `spark_ai_temp_view__1575533453` WHERE `population growth (%)` = 0.7\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mOK\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
      "Final Answer: SELECT AVG(`gdp growth (%)`) FROM `spark_ai_temp_view__1575533453` WHERE `population growth (%)` = 0.7\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "+-------------------+\n",
      "|avg(gdp growth (%))|\n",
      "+-------------------+\n",
      "|-0.3999999999999999|\n",
      "+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "countries_df.ai.transform(desc=\"What is the average GDP growth for all countries with .7 percent population growth?\", cache=False).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92mINFO: \u001b[0mCreating temp view for the transform:\n",
      "df.createOrReplaceTempView(\u001b[33m\"\u001b[39;49;00m\u001b[33mspark_ai_temp_view__1575533453\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\u001b[37m\u001b[39;49;00m\n",
      "\n",
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: The keyword 'Japan' is most similar to the sample values in the `country` column.\n",
      "I need to filter on an exact value from the `country` column, so I will use the tool similar_value to help me choose my filter value.\n",
      "Action: similar_value\n",
      "Action Input: Japan|country|spark_ai_temp_view__1575533453\u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3mJPN\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe correct `country` filter should be 'JPN' because it is semantically closest to the keyword.\n",
      "I will use the column `country` to filter the rows where its value is 'JPN' and then select the `population growth (%)` \n",
      "because the question asks for the countries with higher population growth than Japan.\n",
      "Action: query_sql_db\n",
      "Action Input: \"SELECT `population growth (%)` FROM spark_ai_temp_view__1575533453 WHERE `country` = 'JPN'\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m[('0.2',)]\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe population growth of Japan is 0.2%. Now I will write a query to find the countries with a higher population growth.\n",
      "Action: query_validation\n",
      "Action Input: SELECT `country` FROM `spark_ai_temp_view__1575533453` WHERE `population growth (%)` > 0.2\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mOK\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
      "Final Answer: SELECT `country` FROM `spark_ai_temp_view__1575533453` WHERE `population growth (%)` > 0.2\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "+-------+\n",
      "|country|\n",
      "+-------+\n",
      "|    USA|\n",
      "|    CHN|\n",
      "|    IND|\n",
      "|    BRA|\n",
      "|    CAN|\n",
      "|    AUS|\n",
      "|    FRA|\n",
      "|    GBR|\n",
      "|    ZAF|\n",
      "+-------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "countries_df.ai.transform(desc=\"Which countries had a higher population growth than Japan?\", cache=False).show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Poetry Env",
   "language": "python",
   "name": "poetry-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
